import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from doubleml import DoubleMLClusterData, DoubleMLData


def check_missing_vals(df: pd.DataFrame, *args) -> None:
    """
    print the number of missing entries for each requested column selection
    Args:
        df: input dataframe
        *args: column selectors; each one is used as ``df[selector]`` and
            reported on its own line
    """
    for selector in args:
        # build the label first, then count, mirroring the print argument order
        label = f"Checking missing data in columns... Columns {selector} have missing entries: "
        n_missing = df[selector].isnull().sum()
        print(label, n_missing)


def impute_missing_data(
    df: pd.DataFrame, missing_cols: list[str], fillna_name="unseen"
) -> None:
    """
    convert nan to an explicit categorical value, modifying the data in place

    Every listed column that exists in df and contains at least one NaN is
    cast to category dtype (if it is not already) and its NaNs are replaced by
    fillna_name. Columns without NaNs are left untouched, dtype included.
    A per-column report of remaining NaNs is printed at the end.

    Args:
        df: input dataframe (modified in place)
        missing_cols: potential missing columns; names not present in df are
            ignored
        fillna_name: the value used to fill NAs
    """

    # keep only the (potentially) missing columns that exist in the df.
    # dict.fromkeys drops duplicates while keeping the caller's order, so the
    # printed report is deterministic (a set gives no ordering guarantee).
    existing_columns = list(
        dict.fromkeys(col for col in missing_cols if col in df.columns)
    )

    for col in existing_columns:
        if not df[col].isna().any():
            continue

        filled = df[col]
        if filled.dtype != "category":
            filled = filled.astype("category")
        # add_categories raises ValueError when the category is already
        # present, so only register it when it is really new.
        if fillna_name not in filled.cat.categories:
            filled = filled.cat.add_categories(fillna_name)

        # Replace the column outright with df[col] = ... rather than
        # df.loc[:, col] = ...: the .loc form attempts an in-place write into
        # the existing column, whose dtype does not know the new category yet;
        # recent pandas versions warn about that dtype-incompatible write.
        df[col] = filled.fillna(fillna_name)

    print("NAs after filling nas: ", df[existing_columns].isna().any())


def add_mean_encoding(
    df: pd.DataFrame,
    cont_vars: list[str],
    cols_to_encode: "list[str] | None" = None,
) -> pd.DataFrame:
    """
    use mean encoding of cols_to_encode (default: assignee_id_j and assignee_id_i)
    with the means of all continuous variables to reduce dimensions

    For every column in cols_to_encode, each row receives the mean of every
    continuous variable taken over the rows that share its value in that
    column. The new columns are named theta_<encoded column>_<continuous
    variable>, and the encoded columns themselves are removed.

    mean_encoding: use target encoding for firm FEs to reduce dimension.
    Original author's note: this usage has not been seen in the literature.
    https://arxiv.org/pdf/1908.09874.pdf

    Args:
        df: input dataframe. NOTE: the columns in cols_to_encode are dropped
            from this object in place before the result is built.
        cont_vars: continuous variable names for mean encoding
        cols_to_encode: columns to encode; None selects the default pair
            ["assignee_id_j", "assignee_id_i"]
    Returns:
        df: dataframe with additional encoded columns (and without the
            encoded columns). With an empty cols_to_encode a copy of the
            input is returned unchanged.
    Raises:
        ValueError: if a column in cols_to_encode does not exist in df
    """
    # None stands in for the default so no mutable list is shared across calls.
    if cols_to_encode is None:
        cols_to_encode = ["assignee_id_j", "assignee_id_i"]
    else:
        # drop(columns=...) treats a tuple as one label, so force a list
        cols_to_encode = list(cols_to_encode)

    for col in cols_to_encode:
        if col not in df.columns:
            raise ValueError(
                f"Column {col} does not exist. Please check the column string."
            )

    # Nothing to encode: pd.concat([]) would raise an unrelated ValueError.
    if not cols_to_encode:
        return df.copy()

    cont_vars = list(cont_vars)
    encoded_frames = [
        df[cont_vars + [col]]
        .groupby(col)
        .transform("mean")
        .add_prefix(f"theta_{col}_")
        for col in cols_to_encode
    ]

    df_encoded = pd.concat(encoded_frames, axis=1)
    print("mean encoded feature array: ", df_encoded.shape)
    # release the per-column frames before the final concatenation
    del encoded_frames
    df.drop(columns=cols_to_encode, inplace=True)
    return pd.concat([df, df_encoded], axis=1)


def prepare_dml_data(
    df: pd.DataFrame,
    cate_vars: list[str],
    cont_vars: list[str],
    cluster_data: bool = False,
):
    """
    prepare doubleML data

    The categorical variables are one-hot encoded and the continuous
    variables, together with every column whose name starts with "theta_",
    are standardized. The covariate matrix is assembled in the order
    "allfemale_09_100_i", scaled continuous columns, one-hot columns.
    The outcome is the column "omission" and the treatment is the column
    "allfemale_09_100_j".

    Args:
        df: input dataframe; must contain "omission", "allfemale_09_100_j",
            "allfemale_09_100_i" and, when cluster_data is True, "patent_id_i"
        cate_vars: categorical column names to one-hot encode (may be empty)
        cont_vars: continuous column names to standardize (may be empty)
        cluster_data: prepare doubleML cluster data, clustered on "patent_id_i"
        https://docs.doubleml.org/stable/examples/py_double_ml_multiway_cluster.html
    Returns:
        dml_data: DoubleMLData object ready to feed into DML algo. (DoubleMLData | DoubleMLClusterData)
    """
    n_rows = len(df)

    print(f"encoding {len(cate_vars)} categorical variables.")
    if len(cate_vars) > 0:
        # scikit-learn renamed `sparse` to `sparse_output` in 1.2 and removed
        # the old keyword in 1.4; try the current name first and fall back
        # only for releases that predate the rename.
        try:
            enc = OneHotEncoder(
                handle_unknown="ignore", sparse_output=False, dtype="float64"
            )
        except TypeError:
            enc = OneHotEncoder(
                handle_unknown="ignore", sparse=False, dtype="float64"
            )
        W_cat_tr = enc.fit_transform(df[cate_vars])
    else:
        # the encoder cannot be fitted on zero features
        W_cat_tr = np.empty((n_rows, 0), dtype="float64")
    print(W_cat_tr.shape)

    print(f"encoding {len(cont_vars)} continuous variables.")
    # columns carrying the "theta_" prefix are scaled with the continuous ones
    cols_mean_encoding = [
        col
        for col in df.columns
        if isinstance(col, str) and col.startswith("theta_")
    ]
    cols_to_scale = list(cont_vars) + cols_mean_encoding
    if cols_to_scale:
        W_cont_tr = StandardScaler().fit_transform(df[cols_to_scale])
    else:
        # the scaler cannot be fitted on zero features either
        W_cont_tr = np.empty((n_rows, 0), dtype="float64")
    print(W_cont_tr.shape)

    Y = df["omission"]
    T = df["allfemale_09_100_j"]
    X = df[["allfemale_09_100_i"]]
    # assembled once and shared by both container types
    features = np.concatenate([X, W_cont_tr, W_cat_tr], axis=1)

    # build only the container that was asked for
    if cluster_data:
        C = df["patent_id_i"]
        dml_data = DoubleMLClusterData.from_arrays(
            x=features, y=Y, d=T, cluster_vars=C
        )
    else:
        dml_data = DoubleMLData.from_arrays(x=features, y=Y, d=T)
    print("total number of variables in DML data: ", len(dml_data.all_variables))

    return dml_data
